# ---------------------------------------------------------
# TUM - Technical University of Munich
#
# Original Authors: Taylor Jones
#                   Magdalena Altmann
# Further Authors:  Adrian Wenzel
# Date: 2020-05-06
# Purpose: This source code intends to re-estimate CH4
#          emissions in Munich using a top-down approach.
#          This is the main source file.
# Input: Prior emission estimates (inventories)
#        XCH4 observations (column-averaged mole fractions)
#        Footprints (Sensitivity of receptor to emissions)
# Comment: This source code was developed in 2019;
#          the date above marks its first
#          submission to version control.
# ---------------------------------------------------------


# Close every open graphics device so this run starts without stale plots.
grDevices::graphics.off()
# grDevices::dev.off()  # alternative: close only the current device

library(reshape2)
require('raster')
require('leaflet')
require('proj4')
require('ncdf4')
require('xts')
require('rasterVis')
require('gridExtra')
require('ggplot2')
require('viridisLite')
require('MCMCpack')
require('Metrics')
require('rsq')
require('base64')
require('base64enc')

# Wall-clock start time; the elapsed time is reported at the end of the script.
t1 <- Sys.time()

# The working directory is left unchanged. Helper scripts are located via
# `path_inverse_modeling_engine`, which is pasted directly in front of each
# file name (so it presumably ends with a path separator -- confirm in config).
# setwd(path_inverse_modeling_engine)

# Helper scripts, sourced in this order:
#   plotting.R            - all plotting functions
#   load_inventories.R    - loaders that read inventories into raster stacks
#   inversion_functions.R - the remaining inversion functions
invisible(lapply(
  c("plotting.R", "load_inventories.R", "inversion_functions.R"),
  function(helper_script) {
    source(paste0(path_inverse_modeling_engine, helper_script))
  }
))


# Load the prior: --------------------------------------------------------------
# Footprint file of the first campaign date; it is only needed for its
# projection.
date <- dates[1]  # e.g. "20180827"
# TODO: Switch for different prior file formats
foot_file <- paste0(path_campaign, "/foot/", prefix, "_", date, foot_suffix)

# Main loop: one complete set of inversions per prior inventory ----------------
# For each entry of `inventory_files`, this loop:
#   * loads the prior,
#   * inverts every date in `dates` separately,
#   * optionally inverts all dates jointly (`invert_all_days`),
#   * plots diagnostics,
#   * appends that inventory's emission bars to `bars_df_total`.
# It relies on configuration defined before this script is sourced (dates,
# inventory_files, path_campaign, prefix, foot_suffix, ems, error_method,
# use_transport_error, invert_all_days, sigma_* priors, ...). It also relies on
# the sourced helpers (load_inv2, create_y_df, load_bims, merge_and_invert,
# do_bootstrap, plt_*).

bars_df_total <- data.frame()

# R^2 of a simple linear regression of y on x.
# NOTE: this masks rsq::rsq() from the attached 'rsq' package.
rsq <- function(x, y) summary(lm(y ~ x))$r.squared

for (inventory_name in names(inventory_files)) {
  print(inventory_name)
  inventory_file <- inventory_files[inventory_name]
  # `date` must be reset: the date loop below leaves it at the last date.
  date         <- dates[1]
  foot_file    <- paste0(path_campaign, "/foot/", prefix, "_", date, foot_suffix)  # only needed for the projection
  loaded_prior <- load_inv2(inventory_file, foot_file = foot_file)  # read prior maps, reproject to the footprint projection
  sectors      <- loaded_prior$sectors                              # sectors contained in the prior maps
  # NOTE(review): the prior maps are halved here; the reason is not visible
  # in this file -- confirm this factor is intended.
  ras_stack    <- loaded_prior$ras / 2
  # Per-sector sum of (map value x cell area). The unit conversion
  # (`* to_ggyr`) is disabled, so totals are in map units x area units.
  prior_totals <- cellStats(ras_stack * raster::area(ras_stack), "sum")  #*to_ggyr
  names(prior_totals) <- sectors

  # Per-inventory accumulators, grown across dates:
  out_df_total <- data.frame()
  bk_df_total  <- data.frame()
  bars_df      <- data.frame()
  y_df_total   <- data.frame()
  bim_df_total <- data.frame()

  r2_enhancement_total   <- data.frame()  # one row per date
  r2_concentration_total <- data.frame()  # one row per date

  x_hat_total            <- data.frame()  # one row per date and sector

  te_total <- data.frame()

  # Loop through the dates: ----------------------------------------------------
  for (date in dates) {
    message(date)
    foot_file <- paste0(path_campaign, "/foot/", prefix, "_", date, foot_suffix)  # footprint file
    obs_file  <- paste0(path_campaign, "/obs/", prefix, date, ".csv")              # EM27 observation file
    # Load observations and footprints and multiply footprints x prior maps.
    y_df   <- create_y_df(foot_file, obs_file, ems, sectors, ras_stack)
    bim_df <- load_bims(foot_file, ems)  # background influences

    if (use_transport_error) {
      in_file <- paste0(t_error_directory, prefix, "_", date, "_", t_error_suffix)
      te_df   <- readRDS(in_file)
      # NOTE(review): the mean is scaled linearly by scaling_factor, but the
      # variance is sd^2 * scaling_factor (not * scaling_factor^2) -- confirm
      # this is intended.
      te_df$var <- te_df$sd * te_df$sd * te_df$scaling_factor
      te_df$mu  <- te_df$mu * te_df$scaling_factor
      # Aggregate per (run_time, designator): summed variance, mean of all
      # columns (for the threshold statistic) and sum of all columns.
      te_df_var    <- aggregate(var ~ run_time + designator, te_df, sum)
      te_df_thresh <- aggregate(. ~ run_time + designator, te_df, mean)
      te_df        <- aggregate(. ~ run_time + designator, te_df, sum)
      te_df$sd          <- sqrt(te_df_var$var)  # sd of the sum from the summed variances
      te_df$over_thresh <- te_df_thresh$lower_threshold
      # Positional selection; which columns 3, 4 and 9 are depends on the
      # column layout of the RDS file (assumed: mu, sd, over_thresh).
      te_df <- te_df[, c(1, 2, 3, 4, 9)]
      names(te_df) <- c("recep", "em", "te_mean", "te_sd", "te_over_thresh")
      te_df$recep <- te_df$recep - 4 * 60 * 60  # wrong time zone....thanks posix!
      y_df$te_over_thresh <- 0
      y_df$te_sd <- 0

      # Interpolate the transport-error statistics to each station's
      # observation times.
      for (em in ems) {
        sub_df <- te_df[te_df$em == em, ]
        y_df$te_over_thresh[y_df$em == em] <- approx(sub_df$recep, sub_df$te_over_thresh, y_df$recep[y_df$em == em])$y
        y_df$te_sd[y_df$em == em]          <- approx(sub_df$recep, sub_df$te_sd, y_df$recep[y_df$em == em])$y
      }
      # approx() returns NA outside the interpolation range; use 0.1 there.
      y_df$te_sd[is.na(y_df$te_sd)] <- .1
      y_df$te_over_thresh[is.na(y_df$te_over_thresh)] <- .1

      #y_df <- y_df[ y_df$te_over_thresh < .5, ] #remove footprints with likely point source influence
    } else {
      y_df$te_sd <- sigma_obs_prior
      y_df$te_over_thresh <- 0
    }

    #save footprint (just for plotting):
    #fp        <- stack(foot_file,varname="mb foot")   #choose station for footprint plotting
    #ras_fp       <- fp[[1]]                           #sets which time should be used 1 = 8:00 UTC, 2 = 8:15 UTC, ...

    # Main step: merge footprints, BIMs and observations and invert.
    m <- merge_and_invert(
      y_df = y_df,
      bim_df = bim_df,
      sigma_bkgd_prior = sigma_bkgd_prior,
      sigma_sector_priors = sigma_sector_priors,
      sigma_obs_prior = sigma_obs_prior,
      t_back = t_back,
      bkgd_prior = bkgd_prior,
      use_transport_error = use_transport_error
    )
    m$df$date <- date

    # Emission bars: the first sector's prior total converts scaling factors
    # to absolute emissions.
    # NOTE(review): rows are labelled `sector_to_bootstrap` but always use
    # prior_totals[1] (and x_hat[1] for "posterior"); confirm that
    # sector_to_bootstrap is the first sector.
    if (error_method == "bootstrap") {
      b <- do_bootstrap(
        df = m$df,
        sector = sector_to_bootstrap,
        n_times = bootstrap_reps,
        bs_x_err = bs_x_err,
        bs_y_err = bs_y_err
      )

      bars_df <- rbind(bars_df, data.frame(
        "type"     = date,
        "sector"   = sector_to_bootstrap,
        "emission" = mean(b) * prior_totals[1],
        "lower"    = quantile(b, 0.05) * prior_totals[1],
        "upper"    = quantile(b, 0.95) * prior_totals[1],
        "prior"    = inventory_name
      ))
    }

    if (error_method == "posterior") {
      # Posterior sd of the sector scaling factors; +-1.645 sd spans the
      # 5 %-95 % range, matching the bootstrap quantiles above.
      sigma_obs_pos <- sqrt(diag(m$S_pos)[seq_along(sectors)])
      lower <- m$x_hat[1] - 1.645 * sigma_obs_pos[1]
      upper <- m$x_hat[1] + 1.645 * sigma_obs_pos[1]

      bars_df <- rbind(bars_df, data.frame(
        "type"     = date,
        "sector"   = sector_to_bootstrap,
        "emission" = m$x_hat[1] * prior_totals[1],
        "lower"    = lower * prior_totals[1],
        "upper"    = upper * prior_totals[1],
        "prior"    = inventory_name
      ))
    }

    # R^2 of modelled vs observed values, for the total concentrations and for
    # the enhancements over the background (y_bk).
    r2_enhancement   <- rsq(m$df$ob - m$df$y_bk, m$df$y_hat - m$df$y_bk)
    r2_concentration <- rsq(m$df$ob, m$df$y_hat)

    #rmse_enhancement   <- rmse(m$df$ob-m$df$y_bk,m$df$y_hat-m$df$y_bk)
    #rmse_concentration <- rmse(m$df$ob,m$df$y_hat)

    # Stack up the daily data frames:
    y_df_total   <- rbind(y_df_total, y_df)
    bim_df_total <- rbind(bim_df_total, bim_df)
    out_df_total <- rbind(out_df_total, m$df)
    bk_df_total  <- rbind(bk_df_total, m$bk_df)

    r2_enhancement_total   <- rbind(r2_enhancement_total, data.frame(r2 = r2_enhancement))
    r2_concentration_total <- rbind(r2_concentration_total, data.frame(r2 = r2_concentration))
    x_hat_total            <- rbind(x_hat_total, data.frame(x_hat = m$x_hat[seq_along(sectors)]))

    for (i in seq_along(sectors)) {
      message(paste0("  scaling factor (sector: '", sectors[i], "'): ", m$x_hat[i]))
    }
  }

  # Per-date summary with one row per date and sector. Dates and R^2 values
  # are repeated per sector so they stay aligned with x_hat. A plain cbind()
  # would silently recycle them in the wrong order when there are several
  # sectors.
  n_sectors <- length(sectors)
  result <- data.frame(
    date    = rep(dates, each = n_sectors),
    x_hat   = x_hat_total$x_hat,
    r2_conc = rep(r2_concentration_total$r2, each = n_sectors),
    r2_enh  = rep(r2_enhancement_total$r2, each = n_sectors),
    sector  = rep(sectors, times = length(dates))
  )
  names(result)[1:4] <- c("date", "x_hat", "R2 conc.", "R2 enh.")

  # Inversion using all of the days: -------------------------------------------
  if (invert_all_days) {
    m <- merge_and_invert(
      y_df = y_df_total,
      bim_df = bim_df_total,
      sigma_bkgd_prior = sigma_bkgd_prior,
      sigma_sector_priors = sigma_sector_priors,
      sigma_obs_prior = sigma_obs_prior,
      t_back = t_back,
      bkgd_prior = bkgd_prior,
      use_transport_error = use_transport_error  # same setting as the daily inversions
    )

    if (error_method == "bootstrap") {
      b <- do_bootstrap(
        df = m$df,
        sector = sector_to_bootstrap,
        n_times = bootstrap_reps,
        bs_x_err = bs_x_err,
        bs_y_err = bs_y_err
      )

      bars_df <- rbind(bars_df, data.frame(
        "type"     = "ALL",
        "sector"   = sector_to_bootstrap,
        "emission" = mean(b) * prior_totals[1],
        "lower"    = quantile(b, 0.05) * prior_totals[1],
        "upper"    = quantile(b, 0.95) * prior_totals[1],
        "prior"    = inventory_name
      ))
    }

    if (error_method == "posterior") {
      sigma_obs_pos <- sqrt(diag(m$S_pos)[seq_along(sectors)])
      lower <- m$x_hat[1] - 1.645 * sigma_obs_pos[1]
      upper <- m$x_hat[1] + 1.645 * sigma_obs_pos[1]
      bars_df <- rbind(bars_df, data.frame(
        "type"     = "ALL",
        "sector"   = sector_to_bootstrap,
        "emission" = m$x_hat[1] * prior_totals[1],
        "lower"    = lower * prior_totals[1],
        "upper"    = upper * prior_totals[1],
        "prior"    = inventory_name
      ))
    }

    # Look at the output: ------------------------------------------------------
    message("All days combined:")
    for (i in seq_along(sectors)) {
      message(paste0("  scaling factor (sector: '", sectors[i], "'): ", m$x_hat[i]))
    }

    r2_enhancement   <- rsq(m$df$ob - m$df$y_bk, m$df$y_hat - m$df$y_bk)
    r2_concentration <- rsq(m$df$ob, m$df$y_hat)
  }


  # plotting: ------------------------------------------------------------------
  # NOTE: y_df and bim_df hold the last date only. m is the all-days inversion
  # when invert_all_days is set, otherwise the last date's inversion.
  plt_obs(y_df)                 # plot observations
  plt_stacked(y_df)             # plot the forward model result (prior expected contributions)
  plt_bars(bars_df)             # plot posterior emission bars
  plt_corr(m$df)                # plot model-observation correlation (total concentration)
  plt_corr_no_bkgd(m$df)        # plot model-observation correlation (enhancement)

  plt_bk_pos(bk_df_total, log = TRUE)  #+ ylim(0,.5) # plot posterior background error

  #plt_y_hat(m$df)                  # not needed
  #plt_result(m$df,m$bk_df)         # not needed
  #plt_result(out_df_total,bk_df_total)  # inversion result + obs + background
  #plt_bk_error(m$df)

  plt_bkgd(m$df, m$bk_df)       # plot background

  #plt_BIM(bim_df[which(bim_df$em == "ki"),])              # plot BIM
  plt_BIM(bim_df)

  # plots with background station:
  #plt_bkgd_with_stat(m$df, m$bk_df, y_df_ma, b_df)
  #plt_bkgd_mod_obs(m$df, m$bk_df, y_df_ma)

  # plot inventory map:
  #ras <- ras_stack[[1]]                        # get a RasterLayer from the RasterStack
  #plt_inventory_map(ras,ems,em_lats,em_lons)   # plot the inventory map with em locations

  bars_df_total <- rbind(bars_df_total, bars_df)
}



#plt_b(b)
# Report the total wall-clock run time of the script (t1 is set at the start).
t2 <- Sys.time()
t_elapsed <- t2 - t1
# `units` is spelled out in full; the original `unit =` only worked through
# partial argument matching.
message("Time elapsed: ", round(as.numeric(t_elapsed, units = "secs"), digits = 2), " s")
